"""
Code adapted from https://github.com/if-loops/selective-synaptic-dampening/tree/main/src
https://arxiv.org/abs/2308.07707
"""

import random
import os

from typing import Tuple, List
import sys
import argparse
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset, dataset
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import models
from unlearn import *
from utils import *
import forget
import datasets
import models
import conf
from training_utils import *
import pandas as pd 
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
import torchvision.transforms as transforms
from collections import Counter

import time

def transfer_percent_from_forget_to_retain(trainset, forget_indices, retain_indices, percent_to_transfer, num_classes=10, dataset_name="Cifar10"):
    """
    Move the first X% (per class) of forget_indices into retain_indices.

    Within each class the forget indices are sorted ascending and the lowest
    ``floor(count * percent_to_transfer)`` of them are appended to the retain
    set; the rest stay in the forget set. Neither input sequence is mutated.
    The resulting per-class sizes of both sets are printed.

    Args:
        trainset: dataset object. Labels are read from ``trainset.labels``
            when ``dataset_name == 'MUCAC'`` and from ``trainset.targets``
            otherwise.
        forget_indices: indices into ``trainset`` that form the forget set.
        retain_indices: indices into ``trainset`` that form the retain set.
        percent_to_transfer: fraction (e.g. ``0.25`` for 25%) of every
            class's forget indices to move. Values >= 1 move everything.
        num_classes: labels are expected to lie in ``range(num_classes)``.
        dataset_name: only used to select the label attribute (see above).

    Returns:
        ``(updated_forget, updated_retain)``. ``updated_retain`` is the
        original retain indices followed by the transferred ones;
        ``updated_forget`` is grouped by class and sorted within each class.

    Raises:
        ValueError: if ``percent_to_transfer`` is negative.
        KeyError: if a forget sample has a label outside ``range(num_classes)``.
    """
    if percent_to_transfer < 0:
        # A negative fraction would become a negative slice bound and move
        # "all but the last k" samples, which is never what the caller wants.
        raise ValueError(
            f"percent_to_transfer must be non-negative, got {percent_to_transfer}"
        )

    label_source = trainset.labels if dataset_name == 'MUCAC' else trainset.targets
    targets = np.array(label_source)

    # Bucket the forget indices by their class label.
    forget_by_class = {c: [] for c in range(num_classes)}
    for idx in forget_indices:
        forget_by_class[targets[idx]].append(idx)

    # Guards against float products that land just below an integer
    # (e.g. 100 * 0.29 == 28.999999999999996) being truncated one too low.
    eps = 1e-9

    updated_forget = []
    updated_retain = list(retain_indices)  # copy so the caller's list is untouched
    for class_label in range(num_classes):
        class_forget = sorted(forget_by_class[class_label])  # deterministic order
        num_to_transfer = int(len(class_forget) * percent_to_transfer + eps)
        updated_retain.extend(class_forget[:num_to_transfer])
        updated_forget.extend(class_forget[num_to_transfer:])

    # One counting pass per set instead of one list scan per class.
    retain_counts = Counter(targets[i] for i in updated_retain)
    forget_counts = Counter(targets[i] for i in updated_forget)

    print("\nUpdated class distribution:")
    print("Retain set:")
    for c in range(num_classes):
        print(f"  Class {c}: {retain_counts[c]}")
    print("Forget set:")
    for c in range(num_classes):
        print(f"  Class {c}: {forget_counts[c]}")

    return updated_forget, updated_retain
  
def load_forget_retain_indices(trainset, dataset_name, seed, forget_per_class, csv_path="split_indices.csv", num_classes=10):
    """
    Load forget and retain indices for a given seed from CSV.

    The CSV is filtered on its ``seed``, ``dataset`` and ``forget_per_class``
    columns; rows whose ``split`` column is ``'forget'`` / ``'retain'``
    provide the respective ``index`` values. The per-class size of both
    splits is printed for classes ``0 .. num_classes - 1``.

    Args:
        trainset: dataset object. Labels are read from ``trainset.labels``
            when ``dataset_name == 'MUCAC'`` and from ``trainset.targets``
            otherwise (only used for the printed distributions).
        dataset_name: value to match in the CSV's ``dataset`` column.
        seed: value to match in the CSV's ``seed`` column.
        forget_per_class: value to match in the ``forget_per_class`` column.
        csv_path: path of the CSV file holding the splits.
        num_classes: number of classes to report in the printed distributions.

    Returns:
        ``(forget_indices, retain_indices)`` as lists, in CSV row order.

    Raises:
        ValueError: if no CSV row matches the requested
            seed / dataset / forget_per_class combination.
    """
    df = pd.read_csv(csv_path)
    df_seed = df[
        (df['seed'] == seed) &
        (df['dataset'] == dataset_name) &
        (df['forget_per_class'] == forget_per_class)
    ]
    if len(df_seed) == 0:
        # Returning two empty splits here would only fail later, far from
        # the actual cause (a wrong seed, dataset name or split size).
        raise ValueError(
            f"No rows in {csv_path!r} match seed={seed}, "
            f"dataset={dataset_name!r}, forget_per_class={forget_per_class}"
        )

    forget_indices = df_seed[df_seed['split'] == 'forget']['index'].tolist()
    retain_indices = df_seed[df_seed['split'] == 'retain']['index'].tolist()

    labels = trainset.labels if dataset_name == 'MUCAC' else trainset.targets

    def _as_scalar(label):
        # Tensor labels hash by identity, so they must be unwrapped before
        # counting; plain ints have no .item() and are returned unchanged.
        item = getattr(label, "item", None)
        if not callable(item):
            return label
        try:
            return item()
        except (ValueError, RuntimeError):
            # Multi-element label: leave it as it is.
            return label

    def _print_distribution(title, indices):
        counts = Counter(_as_scalar(labels[i]) for i in indices)
        print(f"\n{title} class distribution for seed {seed}:")
        for c in range(num_classes):
            print(f"Class {c}: {counts[c]}")

    _print_distribution("Retain", retain_indices)
    _print_distribution("Forget", forget_indices)

    return forget_indices, retain_indices
      
if __name__ == '__main__':
    # Script entry point: parse the arguments, load the pretrained model and
    # the forget/retain split, run the selected unlearning method from the
    # ``forget`` module and print the metrics it returns.

    # ---- Arguments -------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("-net", default="ResNet18", type=str,  help="net type")

    # Set default path relative to this script
    current_dir = os.path.dirname(os.path.abspath(__file__))
    default_weight_path = os.path.join(current_dir, "checkpoint", "ResNet18", "ResNet18-Cifar10-20-best.pth")

    parser.add_argument(
        "-weight_path",
        type=str,
        default=default_weight_path,
        help="Path to model weights. If you need to train a new model use pretrain_model.py",
    )
    parser.add_argument(
        "-dataset",
        type=str,
        default = "Cifar10",
        nargs="?",
        choices=["Cifar10", "Cifar20", "Cifar100", "PinsFaceRecognition", "Mnist", "MUCAC"],
        help="dataset to train on",
    )
    parser.add_argument("-classes", type=int, default=10,  help="number of classes")
    parser.add_argument("-gpu", action="store_true", default=True, help="use gpu or not")
    # "-gpu" defaults to True and is a store_true flag, so on its own it can
    # never be switched off; this companion flag makes CPU runs possible.
    parser.add_argument("-no_gpu", dest="gpu", action="store_false", help="force running on cpu")
    parser.add_argument("-b", type=int, default=256, help="batch size for dataloader")
    parser.add_argument("-warm", type=int, default=1, help="warm up training phase")
    parser.add_argument("-lr", type=float, default=0.1, help="initial learning rate")
    parser.add_argument(
        "-method",
        type=str,
        nargs="?",
        default = "baseline",
        choices=
        [
            "baseline",
            "retrain",
            "finetune",
            "teacher",
            "amnesiac",
            "FisherForgetting",
            "ssdtuning"

        ],
        help="select unlearning method from choice set",
    )
    parser.add_argument(
        "-forget_perc", type=float, default=0.1,  help="Percentage of trainset to forget"
    )
    parser.add_argument(
        "-epochs", type=int, default=1, help="number of epochs of unlearning method to use"
    )
    parser.add_argument(
        "-ret_perc", type=int, default=0, help="percentage from forget set to move to retrain"
    )
    parser.add_argument("-seed", type=int, default=7, help="seed for runs")
    parser.add_argument(
        "-forget_per_class",
        type=int,
        default=None,
        help="forget samples per class of the split to load (default: chosen from -dataset)",
    )
    parser.add_argument(
        "-split_csv",
        type=str,
        default=None,
        help="CSV file with the forget/retain split (default: splits/split_indices_<dataset>.csv)",
    )
    args = parser.parse_args()

    # Per-dataset split sizes, taken from the note that accompanied the
    # previously hard-coded value: cifar 10 (500), cifar 20 (50),
    # MNIST (600), MUCAC (527).
    default_forget_per_class = {"Cifar10": 500, "Cifar20": 50, "Mnist": 600, "MUCAC": 527}
    forget_per_class = args.forget_per_class
    if forget_per_class is None:
        forget_per_class = default_forget_per_class.get(args.dataset)
        if forget_per_class is None:
            parser.error(
                f"no default forget_per_class known for dataset {args.dataset}; "
                "pass -forget_per_class explicitly"
            )

    # NOTE(review): assumes every dataset's split file follows the naming of
    # the Cifar10 one ("splits/split_indices_Cifar10.csv") - confirm, or pass
    # -split_csv explicitly.
    split_csv = args.split_csv or f"splits/split_indices_{args.dataset}.csv"

    # ---- Reproducibility -------------------------------------------------
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    batch_size = args.b

    # ---- Device ----------------------------------------------------------
    use_gpu = args.gpu and torch.cuda.is_available()
    if args.gpu and not use_gpu:
        print("CUDA is not available; falling back to CPU.")
    device = "cuda" if use_gpu else "cpu"

    # ---- Networks --------------------------------------------------------
    net = getattr(models, args.net)(num_classes=args.classes)
    # Load onto the CPU first so a checkpoint saved from a GPU can also be
    # read on a CPU-only machine; the model is moved to the GPU below.
    net.load_state_dict(torch.load(args.weight_path, map_location="cpu"))

    # Second network of the same architecture with no checkpoint loaded;
    # handed to the unlearning method as "unlearning_teacher".
    unlearning_teacher = getattr(models, args.net)(num_classes=args.classes)

    if use_gpu:
        net = net.cuda()
        unlearning_teacher = unlearning_teacher.cuda()

    # ---- Data ------------------------------------------------------------
    if args.dataset == "PinsFaceRecognition":
        root = "105_classes_pins_dataset"
    elif args.dataset == "MUCAC":
        root = "./data/MUCAC"
    else:
        root = "./data"

    img_size = 224 if args.net == "ViT" else 32
    dataset_cls = getattr(datasets, args.dataset)
    trainset = dataset_cls(
        root=root, download=True, train=True, unlearning=True, img_size=img_size
    )
    validset = dataset_cls(
        root=root, download=True, train=False, unlearning=True, img_size=img_size
    )

    trainloader = DataLoader(trainset, num_workers=4, batch_size=args.b, shuffle=True)
    validloader = DataLoader(validset, num_workers=4, batch_size=args.b, shuffle=False)

    # The split must be looked up for the dataset that was actually loaded,
    # otherwise the indices refer to samples of a different dataset.
    forget_indices, retain_indices = load_forget_retain_indices(
        trainset=trainset,
        dataset_name=args.dataset,
        seed=args.seed,
        forget_per_class=forget_per_class,
        csv_path=split_csv,
        num_classes=args.classes,
    )

    forget_indices, retain_indices = transfer_percent_from_forget_to_retain(
        trainset=trainset,
        forget_indices=forget_indices,
        retain_indices=retain_indices,
        percent_to_transfer=(args.ret_perc/100),  # Move x% of forget set per class to retain set
        num_classes=args.classes,
        dataset_name=args.dataset,
    )
    retainset = dataset_cls(
        root=root,
        download=True,
        train=True,
        unlearning=False,
        img_size=img_size,
        indices=retain_indices
    )
    forgetset = dataset_cls(
        root=root,
        download=True,
        train=True,
        unlearning=False,
        img_size=img_size,
        indices=forget_indices
    )

    retain_train_dl = DataLoader(retainset, batch_size=args.b, shuffle=True)
    forget_train_dl = DataLoader(forgetset, batch_size=args.b, shuffle=False)

    # The same loaders are reused for the "valid" roles.
    forget_valid_dl = forget_train_dl
    retain_valid_dl = retain_train_dl

    # Change alpha here as described in ssd the paper.
    # Multiplies "selection_weighting" below; currently 1 for every network.
    model_size_scaler = 1

    full_train_dl = DataLoader(
        ConcatDataset((retain_train_dl.dataset, forget_train_dl.dataset)),
        batch_size=batch_size,
    )

    kwargs = {
        "model": net,
        "seed": args.seed,
        "unlearning_teacher": unlearning_teacher,
        "train_dl": trainloader,
        "retain_train_dl": retain_train_dl,
        "retain_valid_dl": retain_valid_dl,
        "forget_train_dl": forget_train_dl,
        "forget_valid_dl": forget_valid_dl,
        "full_train_dl": full_train_dl,
        "valid_dl": validloader,
        "dampening_constant": 1,
        "selection_weighting": 10 * model_size_scaler,
        "num_classes": args.classes,
        "dataset_name": args.dataset,
        "device": device,
        "model_name": args.net,
        "ret_perc": args.ret_perc,
    }

    # ---- Run the unlearning method and time it ---------------------------
    start = time.time()

    # The method is looked up by name in the ``forget`` module and is
    # expected to return these nine metrics in this order.
    testacc, retainacc, zrf, mia, mia_forget_retain, mia_forget_test, mia_retain_test, mia_train_test, d_f = getattr(forget, args.method)(
        **kwargs
    )

    end = time.time()
    time_elapsed = end - start

    # ---- Report ----------------------------------------------------------
    print(f"Test Accuracy: {testacc}")
    print(f"Retain Accuracy: {retainacc}")
    print(f"Zero-Retain Forget (ZRF): {zrf}")
    print(f"Membership Inference Attack (MIA): {mia}")
    print(f"Forget vs Retain Membership Inference Attack (MIA): {mia_forget_retain}")
    print(f"Forget vs Test Membership Inference Attack (MIA): {mia_forget_test}")
    print(f"Test vs Retain Membership Inference Attack (MIA): {mia_retain_test}")
    print(f"Train vs Test Membership Inference Attack (MIA): {mia_train_test}")
    print(f"Forget Set Accuracy (Df): {d_f}")
    print(f"Method Execution Time: {time_elapsed:.2f} seconds")
    
